{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 539,
     "status": "ok",
     "timestamp": 1668226462186,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": 0
    },
    "id": "DxYJWBWWIswz",
    "outputId": "6b5123dd-d4f5-43e8-b8f4-3be029d8594c"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    " \n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "def llm(prompt, stop=[\"\\n\"]):\n",
    "    response = openai.Completion.create(\n",
    "      model=\"text-davinci-002\",\n",
    "      prompt=prompt,\n",
    "      temperature=0,\n",
    "      max_tokens=100,\n",
    "      top_p=1,\n",
    "      frequency_penalty=0.0,\n",
    "      presence_penalty=0.0,\n",
    "      stop=stop\n",
    "    )\n",
    "    return response[\"choices\"][0][\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7bn-tOHILXvQ"
   },
   "outputs": [],
   "source": [
    "import yaml\n",
    "import alfworld\n",
    "import alfworld.agents.environment\n",
    "with open('base_config.yaml') as reader:\n",
    "    config = yaml.safe_load(reader)\n",
    "    \n",
    "split = \"eval_out_of_distribution\"\n",
    "\n",
    "env = getattr(alfworld.agents.environment, config[\"env\"][\"type\"])(config, train_eval=split)\n",
    "env = env.init_env(batch_size=1)\n",
    "\n",
    "def process_ob(ob):\n",
    "    if ob.startswith('You arrive at loc '):\n",
    "        ob = ob[ob.find('. ')+2:]    \n",
    "    return ob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dWrHRouPIqwC"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "folder = './prompts/'\n",
    "prompt_file = 'alfworld_3prompts.json'\n",
    "with open(folder + prompt_file, 'r') as f:\n",
    "    d = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O8jPzg9nOGG8"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "def alfworld_run(prompt, to_print=True, ob=''):\n",
    "    init_prompt = prompt + ob + '\\n>'\n",
    "    prompt = ''\n",
    "    if to_print:\n",
    "        print(ob)\n",
    "        sys.stdout.flush()\n",
    "    for i in range(1, 50):\n",
    "        action = llm(init_prompt + prompt, stop=['\\n']).strip()\n",
    "        observation, reward, done, info = env.step([action])\n",
    "        observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]\n",
    "        if action.startswith('think:'):\n",
    "            observation = 'OK.'\n",
    "        if to_print:\n",
    "            print(f'Act {i}: {action}\\nObs {i}: {observation}')\n",
    "            sys.stdout.flush()\n",
    "        prompt += f' {action}\\n{observation}\\n>'\n",
    "        if done:\n",
    "            return reward\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_O1MRsQa83wL",
    "outputId": "88b5e8ed-c707-43ec-a442-8e557559e68c"
   },
   "outputs": [],
   "source": [
    "prefixes = {\n",
    "    'pick_and_place': 'put',\n",
    "    'pick_clean_then_place': 'clean',\n",
    "    'pick_heat_then_place': 'heat',\n",
    "    'pick_cool_then_place': 'cool',\n",
    "    'look_at_obj': 'examine',\n",
    "    'pick_two_obj': 'puttwo'\n",
    "}\n",
    "cnts = [0] * 6\n",
    "rs = [0] * 6\n",
    "\n",
    "for _ in range(134):\n",
    "    ob, info = env.reset()\n",
    "    ob = '\\n'.join(ob[0].split('\\n\\n')[1:])\n",
    "    name = '/'.join(info['extra.gamefile'][0].split('/')[-3:-1])\n",
    "    print(name)\n",
    "    for i, (k, v) in enumerate(prefixes.items()):\n",
    "        if name.startswith(k):\n",
    "            prompt = 'Interact with a household to solve a task. Here are two examples.\\n' + d[f'react_{v}_1'] + d[f'react_{v}_0'] + '\\nHere is the task.\\n'\n",
    "            print(k, v)\n",
    "            r = alfworld_run(prompt, ob=ob)\n",
    "            rs[i] += r\n",
    "            cnts[i] += 1\n",
    "            break\n",
    "    print(_+1, 'r', r, 'rs', rs, 'cnts', cnts, 'sum(rs)/sum(cnts)', sum(rs) / sum(cnts))\n",
    "    print('------------\\n')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
